{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Load the intent-classification train/test data from the Coggle mirror.\n",
     "import torch\n",
     "import numpy as np\n",
     "import jieba  # NOTE(review): jieba appears unused in the visible cells — confirm before removing\n",
     "import pandas as pd\n",
     "\n",
     "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
     "# Files are tab-separated with no header row: column 0 = query text, column 1 = intent label\n",
     "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n",
     "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)\n",
     "\n",
     "# train_data = train_data.sample(frac=1.0)\n",
     "# Encode string labels as integer codes; `lbl` keeps the code -> label-name mapping\n",
     "train_data[1], lbl = pd.factorize(train_data[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Number of distinct intent classes produced by pd.factorize (12)\n",
     "len(lbl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>还有双鸭山到淮阴的汽车票吗13号的</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>从这里怎么回家</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>随便播放一首专辑阁楼里的佛里的歌</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>给看一下墓王之王嘛</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>我想看挑战两把s686打突变团竞的游戏视频</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       0  1\n",
       "0      还有双鸭山到淮阴的汽车票吗13号的  0\n",
       "1                从这里怎么回家  0\n",
       "2       随便播放一首专辑阁楼里的佛里的歌  1\n",
       "3              给看一下墓王之王嘛  2\n",
       "4  我想看挑战两把s686打突变团竞的游戏视频  3"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Peek at the first rows: column 0 = query text, column 1 = encoded intent label\n",
     "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:08.358976Z",
     "start_time": "2021-03-11T09:35:08.347065Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
     "# NOTE(review): torch/numpy/pandas are re-imported here; harmless but redundant\n",
     "# with the first cell — consider consolidating all imports in one top cell.\n",
     "import torch\n",
     "from sklearn.model_selection import train_test_split\n",
     "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import random\n",
     "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:09.263768Z",
     "start_time": "2021-03-11T09:35:09.227357Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 划分为训练集和验证集\n",
    "# stratify 按照标签进行采样，训练集和验证部分同分布\n",
    "x_train, x_test, train_label, test_label = train_test_split(train_data[0].values,\n",
    "                                                            train_data[1].values,\n",
    "                                                            test_size=0.2,\n",
    "                                                            stratify=train_data[1].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: ylabel='Frequency'>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkQAAAGdCAYAAADzOWwgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlaUlEQVR4nO3df3BU9b3/8ddCsiFAdiFofpUA6QXBFMEheGHHH72YlFSiIxJnoKKkEm8vNnCBiAithfrjNhQGFAYFb7VEp1KEVmwlBUwDhLak/IgNAtYULTZ4k024xWRJan6QnO8ffDnjGq6FZbOb8Hk+ZnbGPeeTzXvPpM1zTs4eHJZlWQIAADBYr3APAAAAEG4EEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjRYR7gJ6go6ND1dXViomJkcPhCPc4AADgMliWpXPnzikpKUm9en35OSCC6DJUV1crOTk53GMAAIAAnD59WoMHD/7SNQTRZYiJiZF04YC6XK4wTwMAAC6Hz+dTcnKy/Xv8yxBEl+Hin8lcLhdBBABAD3M5l7twUTUAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIwXEe4BgFAZtqQo3CNcsY9XZIV7BAAwAmeIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxCCIAAGA8gggAABiPIAIAAMYjiAAAgPEIIgAAYDyCCAAAGI8gAgAAxiOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxCCIAAGA8gggAABiPIAIAAMYjiAAAgPEIIgAAYLyIcA+AnmnYkqJwjwAAQNBwhggAABiPIAIAAMYjiAAAgPEIIgAAYDyCCAAAGI8gAgAAxiOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxuk0QrVixQg6HQwsWLLC3NTc3Ky8vT4MGDVL//v2VnZ2t2tpav6+rqqpSVlaW+vbtq7i4OD3++OM6f/6835p9+/Zp3LhxioqK0vDhw1VYWBiCdwQAAHqKbhFEhw8f1ksvvaQxY8b4bV+4cKHefvttbdu2TaWlpaqurta0adPs/e3t7crKylJra6sOHDigV199VYWFhVq2bJm95tSpU8rKytKkSZNUUVGhBQsW6JFHHtHu3btD9v4AAED3FvYgamxs1MyZM/WTn/xEAwcOtLc3NDTolVde0Zo1a3TnnXcqLS1NmzZt0oEDB/THP/5RkvTOO+/o/fff189+9jPdfPPNuuuuu/TMM8/ohRdeUGtrqyRp48aNSklJ0erVq3XjjTdq7ty5uv/++/Xcc8+F5f0CAIDuJ+xBlJeXp6ysLGVkZPhtLy8vV1tbm9/2UaNGaciQISorK5MklZWV6aabblJ8fLy9JjMzUz6fTydOnLDXfPG1MzMz7de4lJaWFvl8Pr8HAAC4dkWE85tv2bJF7777rg4fPtxpn9frldPp1IABA/y2x8fHy+v12ms+H0MX91/c92VrfD6fPvvsM0VHR3f63gUFBXrqqacCfl8AAKBnCdsZotOnT2v+/Pl6/fXX1adPn3CNcUlLly5VQ0OD/Th9+nS4RwIAAF0obEFUXl6uuro6jRs3ThEREYqIiFBpaanWrVuniIgIxcfHq7W1VfX19X5fV1tbq4SEBElSQkJCp0+dXXz+z9a4XK5Lnh2SpKioKLlcLr8HAAC4doUtiNLT03Xs2DFVVFTYj/Hjx2vmzJn2f0dGRqqkpMT+msrKSlVVVcnj8UiSPB6Pjh07prq6OntNcXGxXC6XUlNT7TWff42Lay6+BgAAQNiuIYqJidHo0aP9tvXr10+DBg2yt+fm5io/P1+xsbFyuVyaN2+ePB6PJk6cKEmaPHmyUlNT9dBDD2nlypXyer168sknlZeXp6ioKEnSnDlztH79
ei1evFizZ8/Wnj17tHXrVhUVFYX2DQMAgG4rrBdV/zPPPfecevXqpezsbLW0tCgzM1Mvvviivb93797asWOHHn30UXk8HvXr1085OTl6+umn7TUpKSkqKirSwoULtXbtWg0ePFgvv/yyMjMzw/GWAABAN+SwLMsK9xDdnc/nk9vtVkNDA9cT/X/DlnCGLRQ+XpEV7hEAoMe6kt/fYb8PEQAAQLgRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwXliDaMOGDRozZoxcLpdcLpc8Ho927txp729ublZeXp4GDRqk/v37Kzs7W7W1tX6vUVVVpaysLPXt21dxcXF6/PHHdf78eb81+/bt07hx4xQVFaXhw4ersLAwFG8PAAD0EGENosGDB2vFihUqLy/XkSNHdOedd+ree+/ViRMnJEkLFy7U22+/rW3btqm0tFTV1dWaNm2a/fXt7e3KyspSa2urDhw4oFdffVWFhYVatmyZvebUqVPKysrSpEmTVFFRoQULFuiRRx7R7t27Q/5+AQBA9+SwLMsK9xCfFxsbq1WrVun+++/X9ddfr82bN+v++++XJH3wwQe68cYbVVZWpokTJ2rnzp26++67VV1drfj4eEnSxo0b9cQTT+jMmTNyOp164oknVFRUpOPHj9vfY8aMGaqvr9euXbsuayafzye3262Ghga5XK7gv+keaNiSonCPYISPV2SFewQA6LGu5Pd3t7mGqL29XVu2bFFTU5M8Ho/Ky8vV1tamjIwMe82oUaM0ZMgQlZWVSZLKysp000032TEkSZmZmfL5fPZZprKyMr/XuLjm4mtcSktLi3w+n98DAABcu8IeRMeOHVP//v0VFRWlOXPmaPv27UpNTZXX65XT6dSAAQP81sfHx8vr9UqSvF6vXwxd3H9x35et8fl8+uyzzy45U0FBgdxut/1ITk4OxlsFAADdVNiDaOTIkaqoqNDBgwf16KOPKicnR++//35YZ1q6dKkaGhrsx+nTp8M6DwAA6FoR4R7A6XRq+PDhkqS0tDQdPnxYa9eu1fTp09Xa2qr6+nq/s0S1tbVKSEiQJCUkJOjQoUN+r3fxU2ifX/PFT6bV1tbK5XIpOjr6kjNFRUUpKioqKO8PAAB0f2E/Q/RFHR0damlpUVpamiIjI1VSUmLvq6ysVFVVlTwejyTJ4/Ho2LFjqqurs9cUFxfL5XIpNTXVXvP517i45uJrAAAAhPUM0dKlS3XXXXdpyJAhOnfunDZv3qx9+/Zp9+7dcrvdys3NVX5+vmJjY+VyuTRv3jx5PB5NnDhRkjR58mSlpqbqoYce0sqVK+X1evXkk08qLy/PPsMzZ84crV+/XosXL9bs2bO1Z88ebd26VUVFfEoKAABcENYgqqur06xZs1RTUyO3260xY8Zo9+7d+sY3viFJeu6559SrVy9lZ2erpaVFmZmZevHFF+2v7927t3bs2KFHH31UHo9H/fr1U05Ojp5++ml7TUpKioqKirRw4UKtXbtWgwcP1ssvv6zMzMyQv18AANA9
dbv7EHVH3IeoM+5DFBrchwgAAtcj70MEAAAQLgQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjBdQEP31r38N9hwAAABhE1AQDR8+XJMmTdLPfvYzNTc3B3smAACAkAooiN59912NGTNG+fn5SkhI0H/8x3/o0KFDwZ4NAAAgJAIKoptvvllr165VdXW1fvrTn6qmpka33XabRo8erTVr1ujMmTPBnhMAAKDLXNVF1REREZo2bZq2bdumH//4x/rwww+1aNEiJScna9asWaqpqQnWnAAAAF3mqoLoyJEj+u53v6vExEStWbNGixYt0kcffaTi4mJVV1fr3nvvDdacAAAAXSYikC9as2aNNm3apMrKSk2ZMkWvvfaapkyZol69LvRVSkqKCgsLNWzYsGDOCgAA0CUCCqINGzZo9uzZ+va3v63ExMRLromLi9Mrr7xyVcMBAACEQkBBdPLkyX+6xul0KicnJ5CXBwAACKmAriHatGmTtm3b1mn7tm3b9Oqrr171UAAAAKEUUBAVFBTouuuu67Q9Li5OP/rRj656KAAAgFAKKIiqqqqUkpLSafvQoUNVVVV11UMBAACEUkBBFBcXp/fee6/T9qNHj2rQoEFXPRQAAEAoBRRE3/rWt/Sf//mf2rt3r9rb29Xe3q49e/Zo/vz5mjFjRrBnBAAA6FIBfcrsmWee0ccff6z09HRFRFx4iY6ODs2aNYtriAAAQI8TUBA5nU698cYbeuaZZ3T06FFFR0frpptu0tChQ4M9HwAAQJcLKIguuuGGG3TDDTcEaxYAAICwCCiI2tvbVVhYqJKSEtXV1amjo8Nv/549e4IyHAAAQCgEFETz589XYWGhsrKyNHr0aDkcjmDPBQAAEDIBBdGWLVu0detWTZkyJdjzAAAAhFxAH7t3Op0aPnx4sGcBAAAIi4CC6LHHHtPatWtlWVaw5wEAAAi5gP5k9vvf/1579+7Vzp079bWvfU2RkZF++998882gDAcAABAKAQXRgAEDdN999wV7FgAAgLAIKIg2bdoU7DkAAADCJqBriCTp/Pnz+u1vf6uXXnpJ586dkyRVV1ersbExaMMBAACEQkBniP72t7/pm9/8pqqqqtTS0qJvfOMbiomJ0Y9//GO1tLRo48aNwZ4TAACgywR0hmj+/PkaP368Pv30U0VHR9vb77vvPpWUlARtOAAAgFAI6AzR7373Ox04cEBOp9Nv+7Bhw/Q///M/QRkMAAAgVAI6Q9TR0aH29vZO2z/55BPFxMRc9VAAAAChFFAQTZ48Wc8//7z93OFwqLGxUcuXL+ef8wAAAD1OQH8yW716tTIzM5Wamqrm5mY98MADOnnypK677jr9/Oc/D/aMAAAAXSqgIBo8eLCOHj2qLVu26L333lNjY6Nyc3M1c+ZMv4usAQAAeoKAgkiSIiIi9OCDDwZzFgAAgLAIKIhee+21L90/a9asgIYBAAAIh4CCaP78+X7P29ra9I9//ENOp1N9+/YliAAAQI8S0KfMPv30U79HY2OjKisrddttt3FRNQAA6HEC/rfMvmjEiBFasWJFp7NHAAAA3V3Qgki6cKF1dXV1MF8SAACgywV0DdGvf/1rv+eWZammpkbr16/XrbfeGpTBAAAAQiWgIJo6darfc4fDoeuvv1533nmnVq9eHYy5AAAAQiagIOro6Aj2HAAAAGET1GuIAAAAeqKAzhDl5+df9to1a9YE8i0AAABCJqAg+tOf/qQ//elPamtr08iRIyVJf/nLX9S7d2+NGzfOXudwOIIzJQAAQBcKKIjuuecexcTE6NVXX9XAgQMlXbhZ48MPP6zbb79djz32WFCHBAAA6EoBXUO0evVqFRQU2DEkSQMHDtSzzz7Lp8wAAECPE1AQ+Xw+nTlzptP2M2fO6Ny5c1c9FAAAQCgFFET33XefHn74Yb355pv65JNP9Mknn+iXv/ylcnNzNW3a
tGDPCAAA0KUCuoZo48aNWrRokR544AG1tbVdeKGICOXm5mrVqlVBHRAAAKCrBRREffv21YsvvqhVq1bpo48+kiT9y7/8i/r16xfU4QAAAELhqm7MWFNTo5qaGo0YMUL9+vWTZVnBmgsAACBkAgqiv//970pPT9cNN9ygKVOmqKamRpKUm5vLR+4BAECPE1AQLVy4UJGRkaqqqlLfvn3t7dOnT9euXbuCNhwAAEAoBHQN0TvvvKPdu3dr8ODBfttHjBihv/3tb0EZDAAAIFQCOkPU1NTkd2boorNnzyoqKuqqhwIAAAilgILo9ttv12uvvWY/dzgc6ujo0MqVKzVp0qTLfp2CggLdcsstiomJUVxcnKZOnarKykq/Nc3NzcrLy9OgQYPUv39/ZWdnq7a21m9NVVWVsrKy1LdvX8XFxenxxx/X+fPn/dbs27dP48aNU1RUlIYPH67CwsIrf+MAAOCaFNCfzFauXKn09HQdOXJEra2tWrx4sU6cOKGzZ8/qD3/4w2W/TmlpqfLy8nTLLbfo/Pnz+t73vqfJkyfr/ffftz/Cv3DhQhUVFWnbtm1yu92aO3eupk2bZn+f9vZ2ZWVlKSEhQQcOHFBNTY1mzZqlyMhI/ehHP5IknTp1SllZWZozZ45ef/11lZSU6JFHHlFiYqIyMzMDOQRASAxbUhTuEa7Yxyuywj0CAFwxhxXgZ+UbGhq0fv16HT16VI2NjRo3bpzy8vKUmJgY8DBnzpxRXFycSktLdccdd6ihoUHXX3+9Nm/erPvvv1+S9MEHH+jGG29UWVmZJk6cqJ07d+ruu+9WdXW14uPjJV24ceQTTzyhM2fOyOl06oknnlBRUZGOHz9uf68ZM2aovr7+si4C9/l8crvdamhokMvlCvj9XUt64i9qhAZBBKC7uJLf31f8J7O2tjalp6errq5O3//+97V161b95je/0bPPPntVMSRdiCxJio2NlSSVl5erra1NGRkZ9ppRo0ZpyJAhKisrkySVlZXppptusmNIkjIzM+Xz+XTixAl7zedf4+Kai68BAADMdsV/MouMjNR7770X9EE6Ojq0YMEC3XrrrRo9erQkyev1yul0asCAAX5r4+Pj5fV67TWfj6GL+y/u+7I1Pp9Pn332maKjo/32tbS0qKWlxX7u8/mu/g0CAIBuK6CLqh988EG98sorQR0kLy9Px48f15YtW4L6uoEoKCiQ2+22H8nJyeEeCQAAdKGALqo+f/68fvrTn+q3v/2t0tLSOv0bZmvWrLmi15s7d6527Nih/fv3+93bKCEhQa2traqvr/c7S1RbW6uEhAR7zaFDh/xe7+Kn0D6/5oufTKutrZXL5ep0dkiSli5dqvz8fPu5z+cjigAAuIZdURD99a9/1bBhw3T8+HGNGzdOkvSXv/zFb43D4bjs17MsS/PmzdP27du1b98+paSk+O1PS0tTZGSkSkpKlJ2dLUmqrKxUVVWVPB6PJMnj8ei//uu/VFdXp7i4OElScXGxXC6XUlNT7TW/+c1v/F67uLjYfo0vioqK4n5KAAAY5IqCaMSIEaqpqdHevXslXfinOtatW9fp+pzLlZeXp82bN+tXv/qVYmJi7Gt+3G63oqOj5Xa7lZubq/z8fMXGxsrlcmnevHnyeDyaOHGiJGny5MlKTU3VQw89pJUrV8rr9erJJ59UXl6eHTVz5szR+vXrtXjxYs2ePVt79uzR1q1bVVTEJ6UAAMAVXkP0xU/o79y5U01NTQF/8w0bNqihoUH/9m//psTERPvxxhtv2Guee+453X333crOztYdd9yhhIQEvfnmm/b+3r17a8eOHerdu7c8Ho8efPBBzZo1S08//bS9JiUlRUVFRSouLtbYsWO1evVqvfzyy9yDCAAASLrC+xD16tVLXq/X/tNUTEyMjh49qq9+9atdNmB3wH2IOuM+RPi/cB8iAN1Fl92HyOFwdLpG6EquGQIAAOiOrugaIsuy9O1vf9u+Nqe5uVlz5szp
9Cmzz/9JCwAAoLu7oiDKycnxe/7ggw8GdRgAAIBwuKIg2rRpU1fNAQAAEDYB3akaAADgWkIQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwXliDaP/+/brnnnuUlJQkh8Oht956y2+/ZVlatmyZEhMTFR0drYyMDJ08edJvzdmzZzVz5ky5XC4NGDBAubm5amxs9Fvz3nvv6fbbb1efPn2UnJyslStXdvVbAwAAPUhYg6ipqUljx47VCy+8cMn9K1eu1Lp167Rx40YdPHhQ/fr1U2Zmppqbm+01M2fO1IkTJ1RcXKwdO3Zo//79+s53vmPv9/l8mjx5soYOHary8nKtWrVKP/zhD/Xf//3fXf7+AABAz+CwLMsK9xCS5HA4tH37dk2dOlXShbNDSUlJeuyxx7Ro0SJJUkNDg+Lj41VYWKgZM2boz3/+s1JTU3X48GGNHz9ekrRr1y5NmTJFn3zyiZKSkrRhwwZ9//vfl9frldPplCQtWbJEb731lj744IPLms3n88ntdquhoUEulyv4b74HGrakKNwjoJv6eEVWuEcAAElX9vu7215DdOrUKXm9XmVkZNjb3G63JkyYoLKyMklSWVmZBgwYYMeQJGVkZKhXr146ePCgveaOO+6wY0iSMjMzVVlZqU8//fSS37ulpUU+n8/vAQAArl3dNoi8Xq8kKT4+3m97fHy8vc/r9SouLs5vf0REhGJjY/3WXOo1Pv89vqigoEBut9t+JCcnX/0bAgAA3Va3DaJwWrp0qRoaGuzH6dOnwz0SAADoQt02iBISEiRJtbW1fttra2vtfQkJCaqrq/Pbf/78eZ09e9ZvzaVe4/Pf44uioqLkcrn8HgAA4NrVbYMoJSVFCQkJKikpsbf5fD4dPHhQHo9HkuTxeFRfX6/y8nJ7zZ49e9TR0aEJEybYa/bv36+2tjZ7TXFxsUaOHKmBAweG6N0AAIDuLKxB1NjYqIqKClVUVEi6cCF1RUWFqqqq5HA4tGDBAj377LP69a9/rWPHjmnWrFlKSkqyP4l244036pvf/Kb+/d//XYcOHdIf/vAHzZ07VzNmzFBSUpIk6YEHHpDT6VRubq5OnDihN954Q2vXrlV+fn6Y3jUAAOhuIsL5zY8cOaJJkybZzy9GSk5OjgoLC7V48WI1NTXpO9/5jurr63Xbbbdp165d6tOnj/01r7/+uubOnav09HT16tVL2dnZWrdunb3f7XbrnXfeUV5entLS0nTddddp2bJlfvcqAgAAZus29yHqzrgPUWfchwj/F+5DBKC7uCbuQwQAABAqBBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAONFhHsASMOWFIV7BAAAjMYZIgAAYDyCCAAAGI8gAgAAxiOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDxCCIAAGA8gggAABiPIAIAAMYjiAAAgPEIIgAAYDyCCAAAGC8i3AMAuLYMW1IU
7hGu2McrssI9AoAw4wwRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAOMRRAAAwHgEEQAAMB5BBAAAjEcQAQAA4xFEAADAeAQRAAAwHkEEAACMRxABAADjEUQAAMB4BBEAADAeQQQAAIxHEAEAAONFhHsAAAi3YUuKwj3CFft4RVa4RwCuKZwhAgAAxiOIAACA8QgiAABgPKOC6IUXXtCwYcPUp08fTZgwQYcOHQr3SAAAoBswJojeeOMN5efna/ny5Xr33Xc1duxYZWZmqq6uLtyjAQCAMHNYlmWFe4hQmDBhgm655RatX79ektTR0aHk5GTNmzdPS5Ys+dKv9fl8crvdamhokMvlCvpsPfETLgAQCD4dh1C6kt/fRnzsvrW1VeXl5Vq6dKm9rVevXsrIyFBZWVmn9S0tLWppabGfNzQ0SLpwYLtCR8s/uuR1AaC76ar/HwUu5eLP2+Wc+zEiiP73f/9X7e3tio+P99seHx+vDz74oNP6goICPfXUU522Jycnd9mMAGAC9/PhngAmOnfunNxu95euMSKIrtTSpUuVn59vP+/o6NDZs2c1aNAgORyOME7Wc/h8PiUnJ+v06dNd8mdGXMBxDg2Oc+hwrEPDlONsWZbOnTunpKSkf7rWiCC67rrr1Lt3b9XW1vptr62tVUJCQqf1UVFRioqK8ts2YMCArhzxmuVyua7p/7F1Fxzn0OA4hw7HOjRMOM7/7MzQRUZ8yszpdCotLU0lJSX2to6ODpWUlMjj8YRxMgAA0B0YcYZIkvLz85WTk6Px48frX//1X/X888+rqalJDz/8cLhHAwAAYWZMEE2fPl1nzpzRsmXL5PV6dfPNN2vXrl2dLrRGcERFRWn58uWd/vSI4OI4hwbHOXQ41qHBce7MmPsQAQAA/F+MuIYIAADgyxBEAADAeAQRAAAwHkEEAACMRxDhquzfv1/33HOPkpKS5HA49NZbb/nttyxLy5YtU2JioqKjo5WRkaGTJ0+GZ9geqqCgQLfccotiYmIUFxenqVOnqrKy0m9Nc3Oz8vLyNGjQIPXv31/Z2dmdbkSKf27Dhg0aM2aMfbM6j8ejnTt32vs5zl1jxYoVcjgcWrBggb2NY331fvjDH8rhcPg9Ro0aZe/nGPsjiHBVmpqaNHbsWL3wwguX3L9y5UqtW7dOGzdu1MGDB9WvXz9lZmaqubk5xJP2XKWlpcrLy9Mf//hHFRcXq62tTZMnT1ZTU5O9ZuHChXr77be1bds2lZaWqrq6WtOmTQvj1D3T4MGDtWLFCpWXl+vIkSO68847de+99+rEiROSOM5d4fDhw3rppZc0ZswYv+0c6+D42te+ppqaGvvx+9//3t7HMf4CCwgSSdb27dvt5x0dHVZCQoK1atUqe1t9fb0VFRVl/fznPw/DhNeGuro6S5JVWlpqWdaFYxoZGWlt27bNXvPnP//ZkmSVlZWFa8xrxsCBA62XX36Z49wFzp07Z40YMcIqLi62vv71r1vz58+3LIuf6WBZvny5NXbs2Evu4xh3xhkidJlTp07J6/UqIyPD3uZ2uzVhwgSVlZWFcbKeraGhQZIUGxsrSSovL1dbW5vfcR41apSGDBnCcb4K7e3t2rJli5qamuTxeDjOXSAvL09ZWVl+x1TiZzqYTp48qaSkJH31q1/VzJkzVVVVJYljfCnG3Kkaoef1eiWp093A4+Pj7X24Mh0dHVqwYIFuvfVWjR49WtKF4+x0Ojv9A8Qc58AcO3ZMHo9Hzc3N6t+/v7Zv367U1FRVVFRwnINoy5Ytevfdd3X48OFO+/iZDo4JEyaosLBQI0eOVE1NjZ566indfvvtOn78OMf4EggioAfJy8vT8ePH/a4DQHCNHDlSFRUVamho0C9+8Qvl5OSotLQ0
3GNdU06fPq358+eruLhYffr0Cfc416y77rrL/u8xY8ZowoQJGjp0qLZu3aro6OgwTtY98SczdJmEhARJ6vSphdraWnsfLt/cuXO1Y8cO7d27V4MHD7a3JyQkqLW1VfX19X7rOc6BcTqdGj58uNLS0lRQUKCxY8dq7dq1HOcgKi8vV11dncaNG6eIiAhFRESotLRU69atU0REhOLj4znWXWDAgAG64YYb9OGHH/LzfAkEEbpMSkqKEhISVFJSYm/z+Xw6ePCgPB5PGCfrWSzL0ty5c7V9+3bt2bNHKSkpfvvT0tIUGRnpd5wrKytVVVXFcQ6Cjo4OtbS0cJyDKD09XceOHVNFRYX9GD9+vGbOnGn/N8c6+BobG/XRRx8pMTGRn+dL4E9muCqNjY368MMP7eenTp1SRUWFYmNjNWTIEC1YsEDPPvusRowYoZSUFP3gBz9QUlKSpk6dGr6he5i8vDxt3rxZv/rVrxQTE2P/fd/tdis6Olput1u5ubnKz89XbGysXC6X5s2bJ4/Ho4kTJ4Z5+p5l6dKluuuuuzRkyBCdO3dOmzdv1r59+7R7926OcxDFxMTY18Bd1K9fPw0aNMjezrG+eosWLdI999yjoUOHqrq6WsuXL1fv3r31rW99i5/nSwn3x9zQs+3du9eS1OmRk5NjWdaFj97/4Ac/sOLj462oqCgrPT3dqqysDO/QPcyljq8ka9OmTfaazz77zPrud79rDRw40Orbt6913333WTU1NeEbuoeaPXu2NXToUMvpdFrXX3+9lZ6ebr3zzjv2fo5z1/n8x+4ti2MdDNOnT7cSExMtp9NpfeUrX7GmT59uffjhh/Z+jrE/h2VZVphaDAAAoFvgGiIAAGA8gggAABiPIAIAAMYjiAAAgPEIIgAAYDyCCAAAGI8gAgAAxiOIAACA8QgiAABgPIIIAAAYjyACAADGI4gAAIDx/h//nBH27R1aGwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
     "# Distribution of query lengths (in characters) — informs the tokenizer max_length choice below\n",
     "train_data[0].apply(len).plot(kind='hist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Tokenizer output fields:\n",
     "# input_ids: token ids for each character\n",
     "# token_type_ids: segment id (first vs. second sentence)\n",
     "# attention_mask: 1 for real tokens, 0 for padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:19.935035Z",
     "start_time": "2021-03-11T09:35:09.908638Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
     "import os\n",
     "# Route Hugging Face downloads through the mirror endpoint\n",
     "os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n",
     "\n",
     "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
     "# NOTE(review): hardcoded absolute local path — prefer a configurable MODEL_DIR constant\n",
     "tokenizer = AutoTokenizer.from_pretrained(\"/home/lyz/hf-models/hfl/chinese-macbert-large\")\n",
     "\n",
     "# max_length=30 covers nearly all queries per the length histogram above\n",
     "train_encoding = tokenizer(list(x_train), truncation=True, padding=True, max_length=30)\n",
     "test_encoding = tokenizer(list(x_test), truncation=True, padding=True, max_length=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 4,927,500 || all params: 330,462,232 || trainable%: 1.491093239362978\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "PeftModelForSequenceClassification(\n",
       "  (base_model): LoraModel(\n",
       "    (model): BertForSequenceClassification(\n",
       "      (bert): BertModel(\n",
       "        (embeddings): BertEmbeddings(\n",
       "          (word_embeddings): Embedding(21128, 1024, padding_idx=0)\n",
       "          (position_embeddings): Embedding(512, 1024)\n",
       "          (token_type_embeddings): Embedding(2, 1024)\n",
       "          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (encoder): BertEncoder(\n",
       "          (layer): ModuleList(\n",
       "            (0-23): 24 x BertLayer(\n",
       "              (attention): BertAttention(\n",
       "                (self): BertSelfAttention(\n",
       "                  (query): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.01, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=1024, out_features=50, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=50, out_features=1024, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                  )\n",
       "                  (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                  (value): lora.Linear(\n",
       "                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                    (lora_dropout): ModuleDict(\n",
       "                      (default): Dropout(p=0.01, inplace=False)\n",
       "                    )\n",
       "                    (lora_A): ModuleDict(\n",
       "                      (default): Linear(in_features=1024, out_features=50, bias=False)\n",
       "                    )\n",
       "                    (lora_B): ModuleDict(\n",
       "                      (default): Linear(in_features=50, out_features=1024, bias=False)\n",
       "                    )\n",
       "                    (lora_embedding_A): ParameterDict()\n",
       "                    (lora_embedding_B): ParameterDict()\n",
       "                  )\n",
       "                  (dropout): Dropout(p=0.1, inplace=False)\n",
       "                )\n",
       "                (output): BertSelfOutput(\n",
       "                  (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                  (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "                  (dropout): Dropout(p=0.1, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (intermediate): BertIntermediate(\n",
       "                (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "                (intermediate_act_fn): GELUActivation()\n",
       "              )\n",
       "              (output): BertOutput(\n",
       "                (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "                (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (pooler): BertPooler(\n",
       "          (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (activation): Tanh()\n",
       "        )\n",
       "      )\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "      (classifier): ModulesToSaveWrapper(\n",
       "        (original_module): Linear(in_features=1024, out_features=12, bias=True)\n",
       "        (modules_to_save): ModuleDict(\n",
       "          (default): Linear(in_features=1024, out_features=12, bias=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "import torch\n",
     "from peft import LoraConfig, TaskType\n",
     "from peft import get_peft_model\n",
     "from transformers import BertForSequenceClassification\n",
     "\n",
     "# LoRA config: rank-50 adapters (applied to the attention query/value projections,\n",
     "# per the printed model below); only ~1.5% of parameters end up trainable.\n",
     "lora_config = LoraConfig(\n",
     "    task_type=TaskType.SEQ_CLS, r=50, lora_alpha=1, lora_dropout=0.01\n",
     ")\n",
     "# num_labels=12 matches len(lbl) computed earlier\n",
     "model = BertForSequenceClassification.from_pretrained(\n",
     "    '/home/lyz/hf-models/hfl/chinese-macbert-large/', num_labels=12\n",
     ")\n",
     "model = get_peft_model(model, lora_config)\n",
     "model.print_trainable_parameters()\n",
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:43.578135Z",
     "start_time": "2021-03-11T09:35:43.571452Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 数据集读取\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    # 读取单个样本\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(int(self.labels[idx]))\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "\n",
    "train_dataset = NewsDataset(train_encoding, train_label)\n",
    "test_dataset = NewsDataset(test_encoding, test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([ 101, 3221,  679, 3221, 5543, 2828, 6825, 5330, 1196, 5326, 5330, 3064,\n",
       "          671,  678, 1450,  102,    0,    0,    0,    0,    0,    0,    0,    0,\n",
       "            0,    0,    0,    0,    0,    0]),\n",
       " 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0]),\n",
       " 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0]),\n",
       " 'labels': tensor(2)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Sanity-check one encoded sample: ids, segment ids, attention mask, and label tensor\n",
     "train_dataset[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:35:44.110121Z",
     "start_time": "2021-03-11T09:35:44.104871Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 精度计算\n",
    "def flat_accuracy(preds, labels):\n",
    "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return np.sum(pred_flat == labels_flat) / len(labels_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:58:49.027161Z",
     "start_time": "2021-03-11T09:58:45.317009Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
     "# NOTE(review): transformers' AdamW is deprecated (see the FutureWarning in the\n",
     "# training output below) — torch.optim.AdamW is the recommended replacement.\n",
     "from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n",
     "from transformers import AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-11T09:44:22.077501Z",
     "start_time": "2021-03-11T09:39:16.473609Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# 训练函数\n",
    "def train(model, train_loader, epoch):\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    iter_num = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for batch in train_loader:\n",
    "        # 正向传播\n",
    "        optim.zero_grad()\n",
    "\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        total_train_loss += loss.item()\n",
    "\n",
    "        # 反向梯度信息\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        # 参数更新\n",
    "        optim.step()\n",
    "        # scheduler.step()\n",
    "\n",
    "        iter_num += 1\n",
    "        if(iter_num % 100 == 0):\n",
    "            print(\"epoth: %d, iter_num: %d, loss: %.4f, %.2f%%\" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "\n",
    "    print(\"Epoch: %d, Average training loss: %.4f\" % (epoch, total_train_loss/len(train_loader)))\n",
    "\n",
    "\n",
    "def validation(model, val_dataloader):\n",
    "    model.eval()\n",
    "    total_eval_accuracy = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in val_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # 正常传播\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
    "\n",
    "    avg_val_accuracy = total_eval_accuracy / len(val_dataloader)\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\" % (total_eval_loss/len(val_dataloader)))\n",
    "    print(\"-------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 2420  2421  2422 ... 12097 12098 12099]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoth: 0, iter_num: 100, loss: 0.5732, 16.53%\n",
      "epoth: 0, iter_num: 200, loss: 0.0713, 33.06%\n",
      "epoth: 0, iter_num: 300, loss: 0.9274, 49.59%\n",
      "epoth: 0, iter_num: 400, loss: 0.0207, 66.12%\n",
      "epoth: 0, iter_num: 500, loss: 0.3288, 82.64%\n",
      "epoth: 0, iter_num: 600, loss: 0.0508, 99.17%\n",
      "Epoch: 0, Average training loss: 0.4703\n",
      "Accuracy: 0.9567\n",
      "Average testing loss: 0.1515\n",
      "-------------------------------\n",
      "epoth: 1, iter_num: 100, loss: 0.3391, 16.53%\n",
      "epoth: 1, iter_num: 200, loss: 0.0278, 33.06%\n",
      "epoth: 1, iter_num: 300, loss: 0.0591, 49.59%\n",
      "epoth: 1, iter_num: 400, loss: 0.3624, 66.12%\n",
      "epoth: 1, iter_num: 500, loss: 0.3643, 82.64%\n",
      "epoth: 1, iter_num: 600, loss: 0.6247, 99.17%\n",
      "Epoch: 1, Average training loss: 0.1747\n",
      "Accuracy: 0.9771\n",
      "Average testing loss: 0.0831\n",
      "-------------------------------\n",
      "epoth: 2, iter_num: 100, loss: 0.0594, 16.53%\n",
      "epoth: 2, iter_num: 200, loss: 0.0027, 33.06%\n",
      "epoth: 2, iter_num: 300, loss: 0.2754, 49.59%\n",
      "epoth: 2, iter_num: 400, loss: 0.0014, 66.12%\n",
      "epoth: 2, iter_num: 500, loss: 0.0024, 82.64%\n",
      "epoth: 2, iter_num: 600, loss: 0.3313, 99.17%\n",
      "Epoch: 2, Average training loss: 0.1035\n",
      "Accuracy: 0.9878\n",
      "Average testing loss: 0.0477\n",
      "-------------------------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0019, 16.53%\n",
      "epoth: 3, iter_num: 200, loss: 0.0095, 33.06%\n",
      "epoth: 3, iter_num: 300, loss: 0.0025, 49.59%\n",
      "epoth: 3, iter_num: 400, loss: 0.0014, 66.12%\n",
      "epoth: 3, iter_num: 500, loss: 0.0018, 82.64%\n",
      "epoth: 3, iter_num: 600, loss: 0.0093, 99.17%\n",
      "Epoch: 3, Average training loss: 0.0590\n",
      "Accuracy: 0.9945\n",
      "Average testing loss: 0.0215\n",
      "-------------------------------\n",
      "epoth: 4, iter_num: 100, loss: 0.4369, 16.53%\n",
      "epoth: 4, iter_num: 200, loss: 0.0008, 33.06%\n",
      "epoth: 4, iter_num: 300, loss: 0.0032, 49.59%\n",
      "epoth: 4, iter_num: 400, loss: 0.0014, 66.12%\n",
      "epoth: 4, iter_num: 500, loss: 0.0303, 82.64%\n",
      "epoth: 4, iter_num: 600, loss: 0.0028, 99.17%\n",
      "Epoch: 4, Average training loss: 0.0383\n",
      "Accuracy: 0.9966\n",
      "Average testing loss: 0.0164\n",
      "-------------------------------\n",
      "[    0     1     2 ... 12097 12098 12099]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoth: 0, iter_num: 100, loss: 0.7529, 16.53%\n",
      "epoth: 0, iter_num: 200, loss: 0.3124, 33.06%\n",
      "epoth: 0, iter_num: 300, loss: 0.0973, 49.59%\n",
      "epoth: 0, iter_num: 400, loss: 0.1018, 66.12%\n",
      "epoth: 0, iter_num: 500, loss: 0.0297, 82.64%\n",
      "epoth: 0, iter_num: 600, loss: 0.2207, 99.17%\n",
      "Epoch: 0, Average training loss: 0.5043\n",
      "Accuracy: 0.9549\n",
      "Average testing loss: 0.1536\n",
      "-------------------------------\n",
      "epoth: 1, iter_num: 100, loss: 0.0143, 16.53%\n",
      "epoth: 1, iter_num: 200, loss: 0.1158, 33.06%\n",
      "epoth: 1, iter_num: 300, loss: 0.2074, 49.59%\n",
      "epoth: 1, iter_num: 400, loss: 0.9149, 66.12%\n",
      "epoth: 1, iter_num: 500, loss: 0.2557, 82.64%\n",
      "epoth: 1, iter_num: 600, loss: 0.0915, 99.17%\n",
      "Epoch: 1, Average training loss: 0.1685\n",
      "Accuracy: 0.9783\n",
      "Average testing loss: 0.0822\n",
      "-------------------------------\n",
      "epoth: 2, iter_num: 100, loss: 0.0121, 16.53%\n",
      "epoth: 2, iter_num: 200, loss: 0.0015, 33.06%\n",
      "epoth: 2, iter_num: 300, loss: 0.0043, 49.59%\n",
      "epoth: 2, iter_num: 400, loss: 0.0266, 66.12%\n",
      "epoth: 2, iter_num: 500, loss: 0.3288, 82.64%\n",
      "epoth: 2, iter_num: 600, loss: 0.0061, 99.17%\n",
      "Epoch: 2, Average training loss: 0.0982\n",
      "Accuracy: 0.9881\n",
      "Average testing loss: 0.0446\n",
      "-------------------------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0015, 16.53%\n",
      "epoth: 3, iter_num: 200, loss: 0.0007, 33.06%\n",
      "epoth: 3, iter_num: 300, loss: 0.0989, 49.59%\n",
      "epoth: 3, iter_num: 400, loss: 0.0496, 66.12%\n",
      "epoth: 3, iter_num: 500, loss: 0.0012, 82.64%\n",
      "epoth: 3, iter_num: 600, loss: 0.0008, 99.17%\n",
      "Epoch: 3, Average training loss: 0.0629\n",
      "Accuracy: 0.9907\n",
      "Average testing loss: 0.0369\n",
      "-------------------------------\n",
      "epoth: 4, iter_num: 100, loss: 0.0009, 16.53%\n",
      "epoth: 4, iter_num: 200, loss: 0.0042, 33.06%\n",
      "epoth: 4, iter_num: 300, loss: 0.0102, 49.59%\n",
      "epoth: 4, iter_num: 400, loss: 0.0068, 66.12%\n",
      "epoth: 4, iter_num: 500, loss: 0.0007, 82.64%\n",
      "epoth: 4, iter_num: 600, loss: 0.0005, 99.17%\n",
      "Epoch: 4, Average training loss: 0.0417\n",
      "Accuracy: 0.9969\n",
      "Average testing loss: 0.0100\n",
      "-------------------------------\n",
      "[    0     1     2 ... 12097 12098 12099]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoth: 0, iter_num: 100, loss: 0.3347, 16.53%\n",
      "epoth: 0, iter_num: 200, loss: 0.4755, 33.06%\n",
      "epoth: 0, iter_num: 300, loss: 0.3466, 49.59%\n",
      "epoth: 0, iter_num: 400, loss: 0.5965, 66.12%\n",
      "epoth: 0, iter_num: 500, loss: 0.6385, 82.64%\n",
      "epoth: 0, iter_num: 600, loss: 0.3223, 99.17%\n",
      "Epoch: 0, Average training loss: 0.5012\n",
      "Accuracy: 0.9476\n",
      "Average testing loss: 0.1758\n",
      "-------------------------------\n",
      "epoth: 1, iter_num: 100, loss: 0.0254, 16.53%\n",
      "epoth: 1, iter_num: 200, loss: 0.0565, 33.06%\n",
      "epoth: 1, iter_num: 300, loss: 0.9993, 49.59%\n",
      "epoth: 1, iter_num: 400, loss: 0.3394, 66.12%\n",
      "epoth: 1, iter_num: 500, loss: 0.2983, 82.64%\n",
      "epoth: 1, iter_num: 600, loss: 0.0101, 99.17%\n",
      "Epoch: 1, Average training loss: 0.1668\n",
      "Accuracy: 0.9772\n",
      "Average testing loss: 0.0792\n",
      "-------------------------------\n",
      "epoth: 2, iter_num: 100, loss: 0.2335, 16.53%\n",
      "epoth: 2, iter_num: 200, loss: 0.0140, 33.06%\n",
      "epoth: 2, iter_num: 300, loss: 0.0057, 49.59%\n",
      "epoth: 2, iter_num: 400, loss: 0.1011, 66.12%\n",
      "epoth: 2, iter_num: 500, loss: 0.0102, 82.64%\n",
      "epoth: 2, iter_num: 600, loss: 0.0291, 99.17%\n",
      "Epoch: 2, Average training loss: 0.0895\n",
      "Accuracy: 0.9876\n",
      "Average testing loss: 0.0452\n",
      "-------------------------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0034, 16.53%\n",
      "epoth: 3, iter_num: 200, loss: 0.0134, 33.06%\n",
      "epoth: 3, iter_num: 300, loss: 0.0020, 49.59%\n",
      "epoth: 3, iter_num: 400, loss: 0.0009, 66.12%\n",
      "epoth: 3, iter_num: 500, loss: 0.0009, 82.64%\n",
      "epoth: 3, iter_num: 600, loss: 0.0006, 99.17%\n",
      "Epoch: 3, Average training loss: 0.0559\n",
      "Accuracy: 0.9946\n",
      "Average testing loss: 0.0222\n",
      "-------------------------------\n",
      "epoth: 4, iter_num: 100, loss: 0.0006, 16.53%\n",
      "epoth: 4, iter_num: 200, loss: 0.0006, 33.06%\n",
      "epoth: 4, iter_num: 300, loss: 0.0014, 49.59%\n",
      "epoth: 4, iter_num: 400, loss: 0.0010, 66.12%\n",
      "epoth: 4, iter_num: 500, loss: 0.0006, 82.64%\n",
      "epoth: 4, iter_num: 600, loss: 0.0005, 99.17%\n",
      "Epoch: 4, Average training loss: 0.0348\n",
      "Accuracy: 0.9944\n",
      "Average testing loss: 0.0264\n",
      "-------------------------------\n",
      "[    0     1     2 ... 12097 12098 12099]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoth: 0, iter_num: 100, loss: 0.3930, 16.53%\n",
      "epoth: 0, iter_num: 200, loss: 0.0485, 33.06%\n",
      "epoth: 0, iter_num: 300, loss: 0.4751, 49.59%\n",
      "epoth: 0, iter_num: 400, loss: 0.2490, 66.12%\n",
      "epoth: 0, iter_num: 500, loss: 0.0389, 82.64%\n",
      "epoth: 0, iter_num: 600, loss: 0.1470, 99.17%\n",
      "Epoch: 0, Average training loss: 0.4779\n",
      "Accuracy: 0.9593\n",
      "Average testing loss: 0.1492\n",
      "-------------------------------\n",
      "epoth: 1, iter_num: 100, loss: 0.0496, 16.53%\n",
      "epoth: 1, iter_num: 200, loss: 0.2660, 33.06%\n",
      "epoth: 1, iter_num: 300, loss: 0.0111, 49.59%\n",
      "epoth: 1, iter_num: 400, loss: 0.4816, 66.12%\n",
      "epoth: 1, iter_num: 500, loss: 0.5089, 82.64%\n",
      "epoth: 1, iter_num: 600, loss: 0.1522, 99.17%\n",
      "Epoch: 1, Average training loss: 0.1635\n",
      "Accuracy: 0.9787\n",
      "Average testing loss: 0.0708\n",
      "-------------------------------\n",
      "epoth: 2, iter_num: 100, loss: 0.0025, 16.53%\n",
      "epoth: 2, iter_num: 200, loss: 0.1081, 33.06%\n",
      "epoth: 2, iter_num: 300, loss: 0.0617, 49.59%\n",
      "epoth: 2, iter_num: 400, loss: 0.0068, 66.12%\n",
      "epoth: 2, iter_num: 500, loss: 0.0023, 82.64%\n",
      "epoth: 2, iter_num: 600, loss: 0.0575, 99.17%\n",
      "Epoch: 2, Average training loss: 0.0911\n",
      "Accuracy: 0.9873\n",
      "Average testing loss: 0.0490\n",
      "-------------------------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0019, 16.53%\n",
      "epoth: 3, iter_num: 200, loss: 0.3401, 33.06%\n",
      "epoth: 3, iter_num: 300, loss: 0.0008, 49.59%\n",
      "epoth: 3, iter_num: 400, loss: 0.3757, 66.12%\n",
      "epoth: 3, iter_num: 500, loss: 0.0009, 82.64%\n",
      "epoth: 3, iter_num: 600, loss: 0.0142, 99.17%\n",
      "Epoch: 3, Average training loss: 0.0546\n",
      "Accuracy: 0.9932\n",
      "Average testing loss: 0.0278\n",
      "-------------------------------\n",
      "epoth: 4, iter_num: 100, loss: 0.0016, 16.53%\n",
      "epoth: 4, iter_num: 200, loss: 0.0018, 33.06%\n",
      "epoth: 4, iter_num: 300, loss: 0.5075, 49.59%\n",
      "epoth: 4, iter_num: 400, loss: 0.0006, 66.12%\n",
      "epoth: 4, iter_num: 500, loss: 0.0007, 82.64%\n",
      "epoth: 4, iter_num: 600, loss: 0.0006, 99.17%\n",
      "Epoch: 4, Average training loss: 0.0314\n",
      "Accuracy: 0.9938\n",
      "Average testing loss: 0.0275\n",
      "-------------------------------\n",
      "[   0    1    2 ... 9677 9678 9679]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/lyz/hf-models/hfl/chinese-macbert-large/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoth: 0, iter_num: 100, loss: 0.1960, 16.53%\n",
      "epoth: 0, iter_num: 200, loss: 0.0984, 33.06%\n",
      "epoth: 0, iter_num: 300, loss: 0.5474, 49.59%\n",
      "epoth: 0, iter_num: 400, loss: 0.0244, 66.12%\n",
      "epoth: 0, iter_num: 500, loss: 0.0375, 82.64%\n",
      "epoth: 0, iter_num: 600, loss: 0.0128, 99.17%\n",
      "Epoch: 0, Average training loss: 0.3445\n",
      "Accuracy: 0.9829\n",
      "Average testing loss: 0.0671\n",
      "-------------------------------\n",
      "epoth: 1, iter_num: 100, loss: 0.2014, 16.53%\n",
      "epoth: 1, iter_num: 200, loss: 0.0067, 33.06%\n",
      "epoth: 1, iter_num: 300, loss: 0.1942, 49.59%\n",
      "epoth: 1, iter_num: 400, loss: 0.0022, 66.12%\n",
      "epoth: 1, iter_num: 500, loss: 0.0015, 82.64%\n",
      "epoth: 1, iter_num: 600, loss: 0.0033, 99.17%\n",
      "Epoch: 1, Average training loss: 0.0807\n",
      "Accuracy: 0.9911\n",
      "Average testing loss: 0.0360\n",
      "-------------------------------\n",
      "epoth: 2, iter_num: 100, loss: 0.0010, 16.53%\n",
      "epoth: 2, iter_num: 200, loss: 0.0401, 33.06%\n",
      "epoth: 2, iter_num: 300, loss: 0.0008, 49.59%\n",
      "epoth: 2, iter_num: 400, loss: 0.0008, 66.12%\n",
      "epoth: 2, iter_num: 500, loss: 0.0010, 82.64%\n",
      "epoth: 2, iter_num: 600, loss: 0.0482, 99.17%\n",
      "Epoch: 2, Average training loss: 0.0437\n",
      "Accuracy: 0.9970\n",
      "Average testing loss: 0.0128\n",
      "-------------------------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0006, 16.53%\n",
      "epoth: 3, iter_num: 200, loss: 0.0004, 33.06%\n",
      "epoth: 3, iter_num: 300, loss: 0.0031, 49.59%\n",
      "epoth: 3, iter_num: 400, loss: 0.0006, 66.12%\n",
      "epoth: 3, iter_num: 500, loss: 0.0184, 82.64%\n",
      "epoth: 3, iter_num: 600, loss: 0.0004, 99.17%\n",
      "Epoch: 3, Average training loss: 0.0172\n",
      "Accuracy: 0.9982\n",
      "Average testing loss: 0.0060\n",
      "-------------------------------\n",
      "epoth: 4, iter_num: 100, loss: 0.0024, 16.53%\n",
      "epoth: 4, iter_num: 200, loss: 0.0002, 33.06%\n",
      "epoth: 4, iter_num: 300, loss: 0.0003, 49.59%\n",
      "epoth: 4, iter_num: 400, loss: 0.0002, 66.12%\n",
      "epoth: 4, iter_num: 500, loss: 0.0004, 82.64%\n",
      "epoth: 4, iter_num: 600, loss: 0.0003, 99.17%\n",
      "Epoch: 4, Average training loss: 0.0153\n",
      "Accuracy: 0.9994\n",
      "Average testing loss: 0.0030\n",
      "-------------------------------\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "\n",
    "# 5-fold CV: train a fresh model per fold and save its weights for ensembling.\n",
    "# Fix: shuffle the folds — labels were factorized in file order (the earlier\n",
    "# sample(frac=1.0) is commented out), so sequential KFold splits can hold out\n",
    "# entire classes that the model never saw during training.\n",
    "kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "n_epochs = 5\n",
    "fold = 0\n",
    "for train_idx, val_idx in kf.split(train_data[0].values, train_data[1].values):\n",
    "    print(train_idx)\n",
    "    train_text = train_data[0].iloc[train_idx]\n",
    "    # Fix: the validation split must use val_idx — the original reused\n",
    "    # train_idx, so it validated on the training rows and reported\n",
    "    # inflated accuracy.\n",
    "    val_text = train_data[0].iloc[val_idx]\n",
    "\n",
    "    train_label = train_data[1].iloc[train_idx].values\n",
    "    val_label = train_data[1].iloc[val_idx].values\n",
    "\n",
    "    train_encoding = tokenizer(list(train_text), truncation=True, padding=True, max_length=30)\n",
    "    val_encoding = tokenizer(list(val_text), truncation=True, padding=True, max_length=30)\n",
    "\n",
    "    # No data augmentation: texts are fed to the model unchanged.\n",
    "    train_dataset = NewsDataset(train_encoding, train_label)\n",
    "    val_dataset = NewsDataset(val_encoding, val_label)\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)\n",
    "\n",
    "    # A fresh model per fold so folds do not contaminate each other.\n",
    "    model = BertForSequenceClassification.from_pretrained(\n",
    "        '/home/lyz/hf-models/hfl/chinese-macbert-large/', num_labels=12\n",
    "    )\n",
    "    model.to(device)\n",
    "\n",
    "    # Optimizer / LR schedule. Fix: total_steps must cover all epochs —\n",
    "    # the original used len(train_loader) * 1, so the linear schedule\n",
    "    # decayed to zero after the first epoch.\n",
    "    optim = AdamW(model.parameters(), lr=1e-5)\n",
    "    total_steps = len(train_loader) * n_epochs\n",
    "    scheduler = get_linear_schedule_with_warmup(optim,\n",
    "                                                num_warmup_steps=0,  # default value in run_glue.py\n",
    "                                                num_training_steps=total_steps)\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        train(model, train_loader, epoch)\n",
    "        validation(model, val_dataloader)\n",
    "\n",
    "    torch.save(model.state_dict(), 'model_' + str(fold) + '.pt')\n",
    "\n",
    "    fold += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize the test set. Labels are unknown at inference time, so dummy 0s\n",
    "# are passed only to satisfy NewsDataset's interface; shuffle=False keeps\n",
    "# prediction rows aligned with test_data.\n",
    "test_encoding = tokenizer(list(test_data[0]), truncation=True, padding=True, max_length=30)\n",
    "test_dataset = NewsDataset(test_encoding, [0] * len(test_data))\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "\n",
    "def prediction(model, test_dataloader):\n",
    "    \"\"\"Run the model over test_dataloader and return raw logits.\n",
    "\n",
    "    Returns an (n_samples, n_labels) numpy array of unnormalized logits\n",
    "    stacked in dataloader order. Raw logits (rather than argmax) are\n",
    "    returned so fold models can be ensembled by summation downstream.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    pred = []\n",
    "    for batch in test_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # Forward pass. Passing the dummy labels makes the model also\n",
    "            # return a loss, which is discarded below.\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        pred.append(logits)\n",
    "        # pred += list(np.argmax(logits, axis=1).flatten())\n",
    "\n",
    "    return np.vstack(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ensemble: sum raw logits from the 5 fold models, then argmax downstream.\n",
    "# Fixes: (1) instantiate the model here instead of relying on `model`\n",
    "# leaking from the training cell (Restart & Run All safe); (2) size `pred`\n",
    "# from the data instead of hardcoding 3000; (3) map_location so checkpoints\n",
    "# saved on GPU also load on a CPU-only machine.\n",
    "model = BertForSequenceClassification.from_pretrained(\n",
    "    '/home/lyz/hf-models/hfl/chinese-macbert-large/', num_labels=12\n",
    ")\n",
    "model.to(device)\n",
    "\n",
    "pred = np.zeros((len(test_data), 12))\n",
    "for path in ['model_0.pt', 'model_1.pt', 'model_2.pt', 'model_3.pt', 'model_4.pt']:\n",
    "    model.load_state_dict(torch.load(path, map_location=device))\n",
    "    pred += prediction(model, test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Map predicted class indices back to their original label strings\n",
    "# (lbl comes from pd.factorize) and write the competition submission file.\n",
    "submission = pd.DataFrame({\n",
    "    'ID': range(1, len(test_data) + 1),\n",
    "    'Target': [lbl[x] for x in pred.argmax(1)],\n",
    "})\n",
    "submission.to_csv('nlp_submit.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
